import torch

# Demo: undo an index permutation by indexing with its inverse, computed via argsort.
values = [10, 20, 30, 40, 50]
x = torch.tensor(values)
perm_idx = torch.tensor([1, 2, 0, 3, 4])

# Gather along dim 0: element i of the result is x[perm_idx[i]].
# Despite the name, this is a reorder by perm_idx, not a sort -> [20, 30, 10, 40, 50].
x_sorted = x.index_select(0, perm_idx)

# For a permutation p, argsort(p) is its inverse q: p[q[j]] == j and q[p[i]] == i.
# Indexing the reordered tensor with q therefore restores the original order.
inv_idx = perm_idx.argsort()  # [2, 0, 1, 3, 4]
x_back = x_sorted.index_select(0, inv_idx)

# The round trip must reproduce x exactly (same shape and elements).
assert x_back.equal(x)
del values
